{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a80d8df22739b896",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Basic Multi-agent Collaboration\n",
    "\n",
    "A single agent can usually operate effectively using a handful of tools within a single domain, but even using powerful models like `gpt-4`, it can be less effective at using many tools. \n",
    "\n",
    "One way to approach complicated tasks is through a \"divide-and-conquer\" approach: create a \"specialized agent\" for each task or domain and route tasks to the correct \"expert\". This means that each agent can become a sequence of LLM calls that chooses how to use a specific \"tool\".\n",
    "\n",
    "This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu, et. al.) shows one way to do this using Burr.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "41629c14988dec58",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:48:28.506330Z",
     "start_time": "2024-04-14T22:48:28.502236Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# %pip install -U burr[start] langchain-community langchain-core langchain-experimental openai"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "642649bc6414efb4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:03:15.394918Z",
     "start_time": "2024-04-14T22:03:15.392234Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Environment variables\n",
    "import os\n",
    "# Make sure TAVILY_API_KEY & OPENAI_API_KEY are set\n",
    "# os.environ['TAVILY_API_KEY'] = 'your_tavily_api_key' # get one at https://tavily.com\n",
    "# os.environ['OPENAI_API_KEY'] = 'your_openai_api_key' # get one at https://platform.openai.com"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "1bd6ddb8fb909d38",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:14:30.125534Z",
     "start_time": "2024-04-14T22:14:27.831337Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# import everything that you'll need\n",
    "import pprint\n",
    "import json\n",
    "from typing import Annotated, Any, Optional\n",
    "from uuid import UUID\n",
    "\n",
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "from langchain_core.callbacks import BaseCallbackHandler\n",
    "from langchain_core.messages import FunctionMessage, HumanMessage\n",
    "from langchain_core.outputs import LLMResult\n",
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_core.tools import tool\n",
    "from langchain_core.utils.function_calling import convert_to_openai_function\n",
    "from langchain_experimental.utilities import PythonREPL\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation\n",
    "\n",
    "from burr import core\n",
    "from burr.core import Action, State, action, default, expr\n",
    "from burr.lifecycle import PostRunStepHook\n",
    "from burr.tracking import client as burr_tclient\n",
    "from burr.visibility import ActionSpanTracer, TracerFactory"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "1ca3c55452268be9",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    " # Define the tools that the agents will use\n",
    "\n",
    "Here we construct the python objects that will be used as tools by our code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ed8e5eefad5726d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:14:30.132311Z",
     "start_time": "2024-04-14T22:14:30.127372Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "tavily_tool = TavilySearchResults(max_results=5)\n",
    "repl = PythonREPL()\n",
    "\n",
    "@tool\n",
    "def python_repl(code: Annotated[str, \"The python code to execute to generate your chart.\"]):\n",
    "    \"\"\"Use this to execute python code. If you want to see the output of a value,\n",
    "    you should print it out with `print(...)`. This is visible to the user.\"\"\"\n",
    "    try:\n",
    "        # Warning: This executes code locally, which can be unsafe when not sandboxed\n",
    "        result = repl.run(code)\n",
    "    except BaseException as e:\n",
    "        return f\"Failed to execute. Error: {repr(e)}\"\n",
    "    return f\"Succesfully executed:\\n```python\\n{code}\\n```\\nStdout: {result}\"\n",
    "\n",
    "tools = [tavily_tool, python_repl]\n",
    "tool_executor = ToolExecutor(tools)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "6126c2b461163f05",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Define the agents\n",
    "Our \"agents\" are effectively an execution of a series of LLM calls. \n",
    "In this example we use LCEL to orchestrate this series of LLM calls.\n",
    "\n",
    "Unfortunately LangChain & LCEL don't give you an easy way to figure out all the prompts used, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c37033a4a91bd2f3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:14:30.791938Z",
     "start_time": "2024-04-14T22:14:30.358694Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# The Agent that we'll use. Our agents here only differ by the system message passed in.\n",
    "def create_agent(llm, tools, system_message: str):\n",
    "    \"\"\"Helper function to create an agent with a system message and tools.\"\"\"\n",
    "    functions = [convert_to_openai_function(t) for t in tools]\n",
    "\n",
    "    prompt = ChatPromptTemplate.from_messages(\n",
    "        [\n",
    "            (\n",
    "                \"system\",\n",
    "                \"You are a helpful AI assistant, collaborating with other assistants.\"\n",
    "                \" Use the provided tools to progress towards answering the question.\"\n",
    "                \" If you are unable to fully answer, that's OK, another assistant with different tools \"\n",
    "                \" will help where you left off. Execute what you can to make progress.\"\n",
    "                \" If you or any of the other assistants have the final answer or deliverable,\"\n",
    "                \" prefix your response with FINAL ANSWER so the team knows to stop.\"\n",
    "                \" You have access to the following tools: {tool_names}.\\n{system_message}\",\n",
    "            ),\n",
    "            MessagesPlaceholder(variable_name=\"messages\"),\n",
    "        ]\n",
    "    )\n",
    "    prompt = prompt.partial(system_message=system_message)\n",
    "    prompt = prompt.partial(tool_names=\", \".join([tool.name for tool in tools]))\n",
    "    return prompt | llm.bind_functions(functions)\n",
    "\n",
    "def _exercise_agent(messages: list, sender: str, agent, name: str) -> dict:\n",
    "    \"\"\"Helper function to exercise the agent code.\"\"\"\n",
    "    result = agent.invoke({\"messages\": messages, \"sender\": sender})\n",
    "    # We convert the agent output into a format that is suitable to append to the global state\n",
    "    if isinstance(result, FunctionMessage):\n",
    "        pass\n",
    "    else:\n",
    "        result = HumanMessage(**result.dict(exclude={\"type\", \"name\"}), name=name)\n",
    "    return {\n",
    "        \"messages\": result,\n",
    "        # Since we have a strict workflow, we can\n",
    "        # track the sender so we know who to pass to next.\n",
    "        \"sender\": name,\n",
    "    }\n",
    "\n",
    "# Define the actual agents via langchain's LCEL\n",
    "llm = ChatOpenAI(model=\"gpt-4-1106-preview\")\n",
    "research_agent = create_agent(\n",
    "    llm,\n",
    "    [tavily_tool],\n",
    "    system_message=\"You should provide accurate data for the chart generator to use.\",\n",
    ")\n",
    "chart_agent = create_agent(\n",
    "    llm,\n",
    "    [python_repl],\n",
    "    system_message=\"Any charts you display will be visible by the user.\",\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "e0ce269d3e58716b",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Define the actions that map to agents\n",
    "We now then create specific actions that map to the agents we need for this not example.\n",
    "    \n",
    "We want a \"chart generator\" action, that will map to an agent that can generate a chart based on provided context/data.\n",
    "\n",
    "We want a \"researcher\" action, that will map to an agent that can search for information on a topic.\n",
    "\n",
    "We then want a \"tool_node\" action, that will run a tool as specified by the prior action, i.e. agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d262371331b8f996",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:28:18.243782Z",
     "start_time": "2024-04-14T22:28:18.231891Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "@action(reads=[\"messages\", \"sender\"], writes=[\"messages\", \"sender\"])\n",
    "def research_node(state: State) -> tuple[dict, State]:\n",
    "    # Research agent and node\n",
    "    result = _exercise_agent(state[\"messages\"], state[\"sender\"], research_agent, \"Researcher\")\n",
    "    return result, state.append(messages=result[\"messages\"]).update(sender=\"Researcher\")\n",
    "\n",
    "\n",
    "@action(reads=[\"messages\", \"sender\"], writes=[\"messages\", \"sender\"])\n",
    "def chart_node(state: State) -> tuple[dict, State]:\n",
    "    # Chart agent and node\n",
    "    result = _exercise_agent(\n",
    "        state[\"messages\"], state[\"sender\"], chart_agent, \"Chart Generator\"\n",
    "    )\n",
    "    return result, state.append(messages=result[\"messages\"]).update(sender=\"Chart Generator\")\n",
    "\n",
    "\n",
    "@action(reads=[\"messages\"], writes=[\"messages\"])\n",
    "def tool_node(state: State) -> tuple[dict, State]:\n",
    "    \"\"\"This runs tools in the graph\n",
    "\n",
    "    It takes in an agent action and calls that tool and returns the result.\"\"\"\n",
    "    messages = state[\"messages\"]\n",
    "    # Based on the continue condition\n",
    "    # we know the last message involves a function call\n",
    "    last_message = messages[-1]\n",
    "    # We construct an ToolInvocation from the function_call\n",
    "    tool_input = json.loads(last_message.additional_kwargs[\"function_call\"][\"arguments\"])\n",
    "    # We can pass single-arg inputs by value\n",
    "    if len(tool_input) == 1 and \"__arg1\" in tool_input:\n",
    "        tool_input = next(iter(tool_input.values()))\n",
    "    tool_name = last_message.additional_kwargs[\"function_call\"][\"name\"]\n",
    "    action = ToolInvocation(\n",
    "        tool=tool_name,\n",
    "        tool_input=tool_input,\n",
    "    )\n",
    "    # We call the tool_executor and get back a response\n",
    "    response = tool_executor.invoke(action)\n",
    "    # We use the response to create a FunctionMessage\n",
    "    function_message = FunctionMessage(\n",
    "        content=f\"{tool_name} response: {str(response)}\", name=action.tool\n",
    "    )\n",
    "    # We return a list, because this will get added to the existing list\n",
    "    return {\"messages\": [function_message]}, state.append(messages=function_message)\n",
    "\n",
    "\n",
    "@action(reads=[], writes=[])\n",
    "def terminal_step(state: State) -> tuple[dict, State]:\n",
    "    \"\"\"Terminal step we have here that does nothing, but it could\"\"\"\n",
    "    return {}, state"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "2f467fb7ae6360f9",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Define the Graph / Application\n",
    "With Burr we need to now construct our application, i.e. graph, by:\n",
    "\n",
    "1. Defining what the actions are and how to transition between them.\n",
    "2. Defining the initial state of the application. In our example this means we need to provide a \"query\" for the agents to work on.\n",
    "\n",
    "Because Burr comes with built in persistence, we can also load a prior execution and continue from \n",
    "any point in its history by specifying a `app_instance_id` and `sequence_number` when building the application."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "79b4924965e77be3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:27:20.323166Z",
     "start_time": "2024-04-14T22:27:20.319966Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Adjust these if you want to load a prior execution\n",
    "app_instance_id = None\n",
    "sequence_id = None\n",
    "project_name = \"demo_lcel-multi-agent\"\n",
    "\n",
    "# CHANGE THIS IF YOU WANT SOMETHING DIFFERENT!\n",
    "default_query = (\"Fetch the UK's GDP over the past 5 years, then draw a line graph of it. \"\n",
    "                 \"Once the python code has been written and the graph drawn, the task is complete.\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d8d69620528ea842",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:27:24.193962Z",
     "start_time": "2024-04-14T22:27:24.188457Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Determine initial state and entry point\n",
    "def default_state_and_entry_point(query: str = None) -> tuple[dict, str]:\n",
    "    \"\"\"Sets the default state & entry point\n",
    "    :param query: the query for the agents to work on.\n",
    "    :return:\n",
    "    \"\"\"\n",
    "    return (\n",
    "        dict(\n",
    "            messages=[\n",
    "                HumanMessage(\n",
    "                    content=query\n",
    "                )\n",
    "            ],\n",
    "            sender=None,\n",
    "        ),\n",
    "        \"researcher\",\n",
    "    )\n",
    "\n",
    "if app_instance_id:\n",
    "    tracker = burr_tclient.LocalTrackingClient(project_name)\n",
    "    persisted_state = tracker.load(\"demo\", app_id=app_instance_id, sequence_no=sequence_id)\n",
    "    if not persisted_state:\n",
    "        print(f\"Warning: No persisted state found for app_id {app_instance_id}.\")\n",
    "        initial_state, entry_point = default_state_and_entry_point(default_query)\n",
    "    else:\n",
    "        initial_state = persisted_state[\"state\"]\n",
    "        # for now we need to manually deserialize LangChain messages into LangChain Objects\n",
    "        from langchain_core import messages\n",
    "\n",
    "        initial_state = initial_state.update(\n",
    "            messages=messages.messages_from_dict(persisted_state[\"state\"][\"messages\"])\n",
    "        )\n",
    "        entry_point = persisted_state[\"position\"]\n",
    "else:\n",
    "    initial_state, entry_point = default_state_and_entry_point(default_query)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5812de2d0cdc8f21",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:28:23.019998Z",
     "start_time": "2024-04-14T22:28:22.190999Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Build the application \n",
    "def build_application(state: dict, entry_point: str):\n",
    "    _app = (\n",
    "        core.ApplicationBuilder()\n",
    "        .with_state(**state)\n",
    "        .with_actions(\n",
    "            researcher=research_node,\n",
    "            charter=chart_node,\n",
    "            call_tool=tool_node,\n",
    "            terminal=terminal_step,\n",
    "        )\n",
    "        .with_transitions(\n",
    "            (\"researcher\", \"call_tool\", expr(\"'function_call' in messages[-1].additional_kwargs\")),\n",
    "            (\"researcher\", \"terminal\", expr(\"'FINAL ANSWER' in messages[-1].content\")),\n",
    "            (\"researcher\", \"charter\", default),\n",
    "            (\"charter\", \"call_tool\", expr(\"'function_call' in messages[-1].additional_kwargs\")),\n",
    "            (\"charter\", \"terminal\", expr(\"'FINAL ANSWER' in messages[-1].content\")),\n",
    "            (\"charter\", \"researcher\", default),\n",
    "            (\"call_tool\", \"researcher\", expr(\"sender == 'Researcher'\")),\n",
    "            (\"call_tool\", \"charter\", expr(\"sender == 'Chart Generator'\")),\n",
    "        )\n",
    "        .with_entrypoint(entry_point)\n",
    "        .with_tracker(project=project_name)\n",
    "        .build()\n",
    "    )\n",
    "    return _app\n",
    "app = build_application(initial_state, entry_point)\n",
    "app.visualize(\n",
    "    output_file_path=\"statemachine\", include_conditions=True, format=\"png\"\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "b05e65c77dcc38b5",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# open up the Burr UI to trace the execution\n",
    "In another terminal run:\n",
    "```bash\n",
    "burr\n",
    "```\n",
    "and then open up the browser to [http://localhost:7241](http://localhost:7241) to see the execution of the application when you exercise `.run()` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9d0b8a14d4614069",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:30:27.699594Z",
     "start_time": "2024-04-14T22:30:27.696685Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# this will run until completion.\n",
    "last_action, last_result, last_state = app2.run(halt_after=[\"terminal\"])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2880bdfee354b7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:30:48.865592Z",
     "start_time": "2024-04-14T22:30:48.861637Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "source": [
    "pprint.pprint(last_state)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "dedde58d-f684-49d9-81a3-df52e75c447d",
   "metadata": {},
   "source": [
    "# Change the Query!\n",
    "Right now we provide the starting query as state. So we just create a new application by adjusting \n",
    "the initial state we provide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8883f456-2217-4802-be0f-3e14cc1475bd",
   "metadata": {},
   "source": [
    "# Let's change the query\n",
    "initial_state, entry_point = default_state_and_entry_point(\"Fetch the USA's GDP over the past 5 years, then draw a line graph of it. \"\n",
    "                 \"Once the python code has been written and the graph drawn, the task is complete.\")\n",
    "app2 = build_application(initial_state, entry_point)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "7173415dbbb0b554",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-14T22:29:47.695377Z",
     "start_time": "2024-04-14T22:28:41.193974Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# this will run until completion.\n",
    "last_action, last_result, last_state = app2.run(halt_after=[\"terminal\"])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "54543e628379f27a",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "pprint.pprint(last_state)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e15b805-4feb-4541-8d2f-5075368bb293",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
